#!/usr/bin/python

import code,pickle,sys,os,re
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from pylab import *
from optparse import OptionParser
import ocrolib
from ocrolib import dbtables,segrec,distcomp,docproc

# Command-line interface: two positional arguments (input character database,
# output cluster database) plus the switches declared below.  The usage text
# is shown verbatim by --help.
parser = OptionParser(usage="""
usage: %prog [options] chars.db cluster.db

Perform fast clustering of characters in a database using a fixed distance
measure.  The resulting cluster databases are often small enough to be 
labeled directly, or they can be clustered further using k-means.

This also updates the cluster id of the char in the original char.db to the id
of the corresponding cluster.
""")

# boolean switches (absent -> None, present -> True)
parser.add_option("-D","--display",
                  dest="display",
                  action="store_true",
                  help="display chars")
parser.add_option("-v","--verbose",
                  dest="verbose",
                  action="store_true",
                  help="verbose output")

# name of the table inside the input database
parser.add_option("-t","--table",
                  dest="table",
                  default="chars",
                  help="table name")

# distance threshold handed to the clusterer
parser.add_option("-e","--epsilon",
                  dest="epsilon",
                  type="float",
                  default=0.1,
                  help="epsilon")

# allow replacing an existing output file
parser.add_option("-o","--overwrite",
                  dest="overwrite",
                  action="store_true",
                  help="overwrite output if it exists")



class ScaledFE:
    """A feature extractor that only rescales the input image to fit into
    a 32x32 (or, generally, r x r box) and normalizes the vector.
    Parameters are r (size of the rescaled image), and normalize (can be
    one of "euclidean", "max", "sum", or None)."""
    def __init__(self,**kw):
        """Set the defaults (r=32, normalize="euclidean") and then hand the
        keyword arguments to ocrolib.set_params, which is expected to apply
        them as overrides."""
        self.r = 32
        self.normalize = "euclidean"
        ocrolib.set_params(self,kw)
    def extract(self,image):
        """Return the feature array for `image`.

        The image is passed through docproc.isotropic_rescale with size
        self.r, converted to a float32 array, and divided by the scale
        selected by self.normalize:

          "euclidean" -- square root of the sum of squares
          "max"       -- largest element
          "sum"       -- sum of absolute values
          None        -- no scaling

        Any other value leaves the array unscaled.  A scale of exactly zero
        (e.g. a blank image) also leaves the array unscaled instead of
        turning every element into NaN."""
        v = array(docproc.isotropic_rescale(image,self.r),'f')
        # instances restored from older pickles may lack the attribute
        # entirely; treat that the same as normalize=None
        mode = getattr(self,"normalize",None)
        if mode is None:
            return v
        if mode=="euclidean":
            scale = sqrt(sum(v**2))
        elif mode=="max":
            scale = amax(v)
        elif mode=="sum":
            scale = sum(abs(v))
        else:
            # unrecognized modes have always been ignored silently
            return v
        # dividing by zero would fill the vector with NaN, and NaN compares
        # false against any distance threshold downstream
        if scale!=0:
            v /= scale
        return v

class DistComp:
    """Brute-force nearest-neighbour store for cluster centre vectors.

    Vectors are flattened and stacked as the rows of the 2-D array `data`
    (`data` is None until the first add()).  `count[i]` is the number of
    vectors that have been folded into row i (1 from add(), plus one per
    merge())."""
    def __init__(self):
        self.data = None
        self.count = []
    def add(self,v):
        """Append `v` (flattened) as a new row with a count of 1."""
        # array() copies, so a later in-place merge() can never write through
        # to the array the caller handed in
        flat = array(v).ravel()
        row = flat.reshape(1,len(flat))
        if self.data is None:
            self.data = row
        else:
            self.data = concatenate((self.data,row),axis=0)
        self.count.append(1.0)
    def distances(self,v):
        """Return the Euclidean distance from `v` to every stored row
        (an empty array when nothing has been stored yet)."""
        if self.data is None: return array([],'f')
        diff = self.data-v.ravel()
        # one vectorized pass over all rows instead of a Python-level loop
        return sqrt((diff**2).sum(axis=1))
    def find(self,v,eps):
        """Return the index of the row nearest to `v`, or -1 when the store
        is empty or the nearest row is farther away than `eps`."""
        if self.data is None: return -1
        ds = self.distances(v)
        i = argmin(ds)
        if ds[i]>eps: return -1
        return i
    def merge(self,i,v,weight):
        """Move row i towards `v`: row <- (1-weight)*row + weight*v.

        With weight = 1/n, where n is the member count including `v`, the
        row stays the running mean of its members.  (Simply adding
        weight*v would let the row grow without bound as members are
        merged, pushing it away from every normalized input.)"""
        row = self.data[i,:]
        self.data[i,:] = row+(v.ravel()-row)*weight
        self.count[i] += 1.0
    def length(self):
        """Return the number of stored rows (0 when empty)."""
        if self.data is None: return 0
        return self.data.shape[0]
    def counts(self,i):
        """Return the number of vectors folded into row i."""
        return self.count[i]
    def vector(self,i):
        """Return row i (a view into the internal array, not a copy)."""
        return self.data[i,:]
    def nearest(self,v):
        """Return the index of the row nearest to `v`, with no distance
        threshold; -1 when the store is empty."""
        if self.data is None: return -1
        ds = self.distances(v)
        i = argmin(ds)
        return i

def test_DistComp():
    """Smoke test: store 33 random 17-element vectors in a DistComp, then
    look up the stored row 3 with eps=0.5 and print the index found."""
    store = DistComp()
    for _ in range(33):
        store.add(randn(17))
    found = store.find(store.data[3],0.5)
    # a single parenthesized value prints the same with the print statement
    print(found)

class FastCluster:
    """Single-pass clustering of character images.

    Every image given to add() is turned into a feature vector by ScaledFE
    and compared against the existing cluster centres held in a DistComp.
    If the nearest centre is within `eps` the image joins that cluster and
    is merged into its centre; otherwise it starts a new cluster."""
    def __init__(self,eps=0.05):
        self.eps = eps
        self.ex = ScaledFE()
        self.dc = DistComp()
        # per cluster: {class label: number of members carrying that label}
        self.classes = []
        # per cluster: number of members
        self.counts = []
        # number of add() calls so far
        self.total = 0
    def add(self,c,cls=None):
        """Assign image `c` (a numpy array) with class label `cls` to a
        cluster and return the cluster index.

        The caller's array is not modified."""
        self.total += 1
        # Normalize out of place: an in-place `c /= ...` would rescale the
        # caller's array and fails for integer images.  A blank image has
        # zero norm and is passed on unscaled rather than turned into NaN.
        energy = sqrt(sum(c**2))
        if energy>0:
            c = c/energy
        v = self.ex.extract(c)
        i = self.dc.find(v,self.eps)
        if i<0:
            # nothing close enough: open a new cluster
            self.dc.add(v)
            self.classes.append({cls:1})
            self.counts.append(1)
            return len(self.counts)-1
        else:
            self.classes[i][cls] = self.classes[i].get(cls,0)+1
            self.counts[i] += 1
            # weight 1/n with n counting the new member
            self.dc.merge(i,v,1.0/self.counts[i])
            return i
    def biniter(self):
        """Yield (index, centre vector, member count, key) per cluster;
        key is always the empty string."""
        # len(self.counts) is the number of clusters and, unlike asking the
        # DistComp, is well defined before the first add()
        for i in range(len(self.counts)):
            key = ""
            v = self.dc.vector(i)
            # the clusterer's own tally is the real cluster size
            count = self.counts[i]
            yield i,v,count,key
    def cls(self,i):
        """Return (label, count) for the most frequent class label in
        cluster i."""
        # max() returns the first maximal item, which is the same item the
        # stable descending sort used to put first
        return max(self.classes[i].items(),key=lambda item:item[1])
    def stats(self):
        """Return "<images added> <number of clusters>" as a string."""
        return " ".join([str(self.total),str(len(self.counts))])
    def save(self,file):
        """Write one row per cluster to the cluster database `file`.

        Each row holds the centre as a square 8-bit image scaled so its
        largest value maps to 255, the majority class label and its count,
        the repr() of the full label histogram, and the cluster index."""
        # imported here because the module itself never imports math; the
        # name was only reachable through the pylab star import
        import math
        table = dbtables.ClusterTable(file)
        table.create(image="blob",cls="text",count="integer",classes="text",cluster="integer")
        table.converter("image",dbtables.SmallImage())
        # NOTE(review): there is no explicit commit()/close() on this table;
        # presumably dbtables persists each set() -- confirm.
        for i,v,_members,_key in self.biniter():
            peak = amax(v)
            if peak!=0:
                image = array(v/peak*255.0,'B')
            else:
                # an all-zero centre stays an all-zero image instead of 0/0
                image = array(v*0.0,'B')
            # the flat vector must describe a square image
            r = int(round(math.sqrt(image.size)))
            assert r*r==image.size
            image.shape = (r,r)
            cls,count = self.cls(i)
            classes = repr(self.classes[i])
            table.set(image=image,cls=cls,count=count,classes=classes,cluster=i)

# Script body: cluster every character image in the input database, write
# the assigned cluster id back to each character row, then save the cluster
# centres to the output database.

(options,args) = parser.parse_args()

# exactly two positional arguments: input database and output database
if len(args)!=2:
    parser.print_help()
    sys.exit(0)

input = args[0]
output = args[1]

# refuse to clobber an existing output file unless --overwrite was given
if os.path.exists(output):
    if not options.overwrite:
        sys.stderr.write("%s: already exists\n"%output)
        sys.exit(1)
    else:
        os.unlink(output)

ion()
show()

# open the relevant tables

table = dbtables.Table(input,options.table)
table.converter("image",dbtables.SmallImage())
# NOTE(review): presumably creates the table/columns when missing -- confirm
# against dbtables.
table.create(image="blob",cluster="integer",cls="integer")

# The cluster ids must be written to the table selected with --table, not
# to a hard-coded "chars".  SQL placeholders cannot stand in for identifiers,
# so the name is inserted as a double-quoted identifier with embedded quotes
# doubled.
update_sql = 'update "%s" set cluster=? where id=?'%options.table.replace('"','""')

binned = FastCluster(options.epsilon)
total = 0
for row in table.get():
    # get the image and the class out of the record
    raw = row.image
    cls = row.cls

    # don't store images that are too large
    if raw.shape[0]>255 or raw.shape[1]>255: continue

    # Skip blank images: dividing by a zero maximum yields NaN, and a NaN
    # vector compares false against the distance threshold, so it would be
    # merged into an existing cluster and corrupt its centre.  Like oversized
    # images, skipped rows keep whatever cluster id they already had.
    peak = float(amax(raw))
    if peak==0: continue

    # make sure the maximum is 1.0
    raw = raw/peak

    # add it to the binned clusterer
    cluster = binned.add(raw,cls)

    # measure and report progress
    total+=1
    if total%1000==0:
        print("# %s chars %s"%(total,binned.stats()))

    # record which cluster the character was assigned to
    table.execute(update_sql,[cluster,row.id])

table.commit()
table.close()

# FIXME optionally perform k-means clustering here so that we can do
# everything in one step and keep the cluster labels updated more easily

# save the clustered data
print("# %s"%binned.stats())
binned.save(output)
